{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Save and Restore\n",
    "- Save the parameters of a model to disk\n",
    "- Restore the parameters of a model from a checkpoint\n",
    "- Customized model initialization\n",
    "- How to finetune a model\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Save the parameters of your model to disk."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow.python.framework import ops\n",
    "import numpy as np\n",
    "tf.reset_default_graph()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model definition\n",
    "\n",
    "# Shared by every layer below. They are looked up when the functions are called,\n",
    "# so they only have to exist before infer() is first invoked.\n",
    "l2_regularizer = tf.contrib.layers.l2_regularizer(1.0)\n",
    "xavier = tf.contrib.layers.xavier_initializer_conv2d()\n",
    "\n",
    "\n",
    "def prelu(x, name='prelu'):\n",
    "    \"\"\"Parametric ReLU: relu(x) plus alpha times the negative part of x.\n",
    "\n",
    "    One alpha per entry of the last axis of x is created under variable scope\n",
    "    `name`, initialised to 0.25 and L2-regularised.\n",
    "    \"\"\"\n",
    "    with tf.variable_scope(name):\n",
    "        num_channels = x.get_shape()[-1]\n",
    "        alphas = tf.get_variable('alpha', num_channels, initializer=tf.constant_initializer(0.25), regularizer=l2_regularizer, dtype=tf.float32)\n",
    "    positive_part = tf.nn.relu(x)\n",
    "    # (x - |x|) / 2 equals x for negative x and 0 otherwise.\n",
    "    negative_part = tf.multiply(alphas, (x - abs(x)) * 0.5)\n",
    "    return positive_part + negative_part\n",
    "\n",
    "\n",
    "def first_conv(input, num_output, name):\n",
    "    \"\"\"Stride-2 3x3 convolution (with bias) followed by PReLU.\n",
    "\n",
    "    The PReLU reuses `name` as its own scope, so its alpha is nested one level\n",
    "    deeper under the same name (e.g. conv1/conv1/alpha).\n",
    "    \"\"\"\n",
    "    with tf.variable_scope(name):\n",
    "        conv = tf.layers.conv2d(input, num_output, kernel_size=[3, 3], strides=(2, 2), padding='same', kernel_initializer=xavier, bias_initializer=tf.zeros_initializer(), kernel_regularizer=l2_regularizer, bias_regularizer=l2_regularizer)\n",
    "        return prelu(conv, name=name)\n",
    "\n",
    "\n",
    "def block(input, name, num_output):\n",
    "    \"\"\"Residual unit: two bias-free 3x3 conv + PReLU stages added back onto `input`.\n",
    "\n",
    "    The final tf.add expects the convolution output to be shape-compatible with\n",
    "    `input`, i.e. `num_output` should match the channel count of `input`.\n",
    "    \"\"\"\n",
    "    with tf.variable_scope(name):\n",
    "        conv_kwargs = dict(kernel_size=[3, 3], strides=[1, 1], padding='same', use_bias=False, kernel_regularizer=l2_regularizer)\n",
    "        residual = tf.layers.conv2d(input, num_output, kernel_initializer=tf.random_normal_initializer(stddev=0.01), **conv_kwargs)\n",
    "        # The PReLU scopes are the fixed strings 'name1' / 'name2' (not derived from\n",
    "        # the `name` argument); checkpoint keys depend on them.\n",
    "        residual = prelu(residual, name='name1')\n",
    "        residual = tf.layers.conv2d(residual, num_output, kernel_initializer=tf.random_normal_initializer(stddev=0.01), **conv_kwargs)\n",
    "        residual = prelu(residual, name='name2')\n",
    "        return tf.add(input, residual)\n",
    "\n",
    "\n",
    "def get_shape(tensor):\n",
    "    \"\"\"Return the shape of `tensor` as a list, using static sizes where known.\n",
    "\n",
    "    Dimensions whose static size is None are replaced by the matching scalar\n",
    "    tensor taken from tf.shape(tensor).\n",
    "    \"\"\"\n",
    "    static_dims = tensor.shape.as_list()\n",
    "    dynamic_dims = tf.unstack(tf.shape(tensor))\n",
    "    return [dyn if stat is None else stat for stat, dyn in zip(static_dims, dynamic_dims)]\n",
    "\n",
    "\n",
    "def infer(input, embedding_size=512):\n",
    "    \"\"\"Build the backbone on `input` and return a (batch, embedding_size) feature tensor.\n",
    "\n",
    "    Four stages, each a stride-2 first_conv followed by residual blocks, then a\n",
    "    dense projection of the flattened feature map. Prints the shape of the\n",
    "    final feature map as a side effect.\n",
    "    \"\"\"\n",
    "    # (stage scope, first conv name, output channels, residual block names)\n",
    "    stages = [\n",
    "        ('conv1_', 'conv1', 64, ['conv1_23']),\n",
    "        ('conv2_', 'conv2', 128, ['conv2_23', 'conv2_45']),\n",
    "        ('conv3_', 'conv3', 256, ['conv3_23', 'conv3_45', 'conv3_67', 'conv3_89']),\n",
    "        ('conv4_', 'conv4', 512, ['conv4_23']),\n",
    "    ]\n",
    "    network = input\n",
    "    for stage_scope, conv_name, channels, block_names in stages:\n",
    "        with tf.variable_scope(stage_scope):\n",
    "            network = first_conv(network, channels, name=conv_name)\n",
    "            for block_name in block_names:\n",
    "                network = block(network, block_name, channels)\n",
    "    with tf.variable_scope('feature'):\n",
    "        dims = get_shape(network)\n",
    "        print(dims)\n",
    "        # Flatten everything except the leading (batch) dimension.\n",
    "        flat = tf.reshape(network, [dims[0], np.prod(dims[1:])])\n",
    "        feature = tf.layers.dense(flat, embedding_size, kernel_regularizer=l2_regularizer, kernel_initializer=xavier)\n",
    "    return feature"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 95,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1, 7, 6, 512]\n"
     ]
    }
   ],
   "source": [
    "# inference\n",
    "tf.reset_default_graph()\n",
    "#image = tf.random_normal([1,112,96,3])\n",
    "#image = tf.constant(np.random.normal(size=[1,112,96,3]),dtype=tf.float32)\n",
    "img_init = np.random.normal(size=[1,112,96,3])\n",
    "image = tf.constant(img_init,dtype=tf.float32)\n",
    "with tf.variable_scope('sphere_net'):\n",
    "    feature = infer(image)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 96,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<tf.Variable 'sphere_net/conv1_/conv1/conv2d/kernel:0' shape=(3, 3, 3, 64) dtype=float32_ref>\n"
     ]
    }
   ],
   "source": [
    "vars = tf.global_variables()\n",
    "print vars[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 97,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1, 1)\n"
     ]
    }
   ],
   "source": [
    "# loss\n",
    "pred = tf.layers.dense(feature,1,name='pred')\n",
    "print pred.get_shape()\n",
    "loss = tf.nn.l2_loss(pred-1)\n",
    "optimizer = tf.train.GradientDescentOptimizer(0.0001).minimize(loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 98,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.45344 [[ 0.04769798]]\n",
      "[[ 0.06438893]]\n",
      "0.0117716 [[ 0.84656197]]\n",
      "[[ 0.84942561]]\n",
      "0.000257565 [[ 0.9773035]]\n",
      "[[ 0.97773653]]\n",
      "5.47655e-06 [[ 0.99669045]]\n",
      "[[ 0.99675322]]\n",
      "1.71822e-07 [[ 0.99941379]]\n",
      "[[ 0.99942195]]\n"
     ]
    }
   ],
   "source": [
    "# train and save model\n",
    "saver = tf.train.Saver()\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    sess.run(tf.global_variables_initializer())\n",
    "    for i in xrange(500):\n",
    "        loss_np,pred_np,_ = sess.run([loss,pred,optimizer])\n",
    "        if i % 100 ==0:\n",
    "            print loss_np,pred_np      \n",
    "            saver.save(sess,'model/sphere',global_step=i,write_meta_graph=False)\n",
    "            pred_nd = sess.run(pred)\n",
    "            print(pred_nd)\n",
    "      "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 99,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'tensorflow.python.training.checkpoint_state_pb2.CheckpointState'>\n",
      "model/sphere-400\n",
      "Tensor(\"Const:0\", shape=(1, 112, 96, 3), dtype=float32)\n",
      "[1, 7, 6, 512]\n",
      "INFO:tensorflow:Restoring parameters from model/sphere-400\n",
      "[[ 0.99942195]]\n"
     ]
    }
   ],
   "source": [
    "# restore the model and test\n",
    "tf.reset_default_graph()\n",
    "ckpt = tf.train.get_checkpoint_state('model')\n",
    "\n",
    "print type(ckpt)\n",
    "print ckpt.model_checkpoint_path\n",
    "image = tf.constant(img_init,dtype=tf.float32)\n",
    "print image\n",
    "#image = tf.random_normal([1,112,96,3])\n",
    "#image = tf.constant(np.random.normal(size=[1,112,96,3]),dtype=tf.float32)\n",
    "with tf.variable_scope('sphere_net'):\n",
    "    feature = infer(image)\n",
    "pred = tf.layers.dense(feature,1,name='pred')\n",
    "saver = tf.train.Saver(tf.trainable_variables())\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    saver.restore(sess,ckpt.model_checkpoint_path)\n",
    "    #saver.restore(sess,'model/sphere-100')\n",
    "    print(sess.run(pred))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Customized model initialization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 106,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1, 7, 6, 512]\n",
      "[1, 7, 6, 512]\n",
      "[<tf.Variable 'pred/kernel:0' shape=(512, 1) dtype=float32_ref>, <tf.Variable 'pred/bias:0' shape=(1,) dtype=float32_ref>]\n",
      "INFO:tensorflow:Initialize variable sphere_net_left/conv3_/conv3_67/conv2d_1/kernel from checkpoint model/sphere-100 with sphere_net/conv3_/conv3_67/conv2d_1/kernel\n",
      "INFO:tensorflow:Initialize variable sphere_net_left/conv1_/conv1/conv1/alpha from checkpoint model/sphere-100 with sphere_net/conv1_/conv1/conv1/alpha\n",
      "INFO:tensorflow:Initialize variable sphere_net_left/conv1_/conv1_23/conv2d_1/kernel from checkpoint model/sphere-100 with sphere_net/conv1_/conv1_23/conv2d_1/kernel\n",
      "INFO:tensorflow:Initialize variable sphere_net_left/feature/dense/kernel from checkpoint model/sphere-100 with sphere_net/feature/dense/kernel\n",
      "INFO:tensorflow:Initialize variable sphere_net_left/conv3_/conv3/conv2d/bias from checkpoint model/sphere-100 with sphere_net/conv3_/conv3/conv2d/bias\n",
      "INFO:tensorflow:Initialize variable sphere_net_left/conv3_/conv3_89/name2/alpha from checkpoint model/sphere-100 with sphere_net/conv3_/conv3_89/name2/alpha\n",
      "INFO:tensorflow:Initialize variable sphere_net_left/conv4_/conv4_23/name1/alpha from checkpoint model/sphere-100 with sphere_net/conv4_/conv4_23/name1/alpha\n",
      "INFO:tensorflow:Initialize variable sphere_net_left/conv2_/conv2_23/conv2d_1/kernel from checkpoint model/sphere-100 with sphere_net/conv2_/conv2_23/conv2d_1/kernel\n",
      "INFO:tensorflow:Initialize variable sphere_net_left/conv3_/conv3_23/name2/alpha from checkpoint model/sphere-100 with sphere_net/conv3_/conv3_23/name2/alpha\n",
      "INFO:tensorflow:Initialize variable sphere_net_left/conv3_/conv3_45/name1/alpha from checkpoint model/sphere-100 with sphere_net/conv3_/conv3_45/name1/alpha\n",
      "INFO:tensorflow:Initialize variable sphere_net_left/conv2_/conv2_45/name1/alpha from checkpoint model/sphere-100 with sphere_net/conv2_/conv2_45/name1/alpha\n",
      "INFO:tensorflow:Initialize variable sphere_net_left/conv3_/conv3_45/name2/alpha from checkpoint model/sphere-100 with sphere_net/conv3_/conv3_45/name2/alpha\n",
      "INFO:tensorflow:Initialize variable sphere_net_left/conv3_/conv3_89/name1/alpha from checkpoint model/sphere-100 with sphere_net/conv3_/conv3_89/name1/alpha\n",
      "INFO:tensorflow:Initialize variable sphere_net_left/conv2_/conv2_45/conv2d/kernel from checkpoint model/sphere-100 with sphere_net/conv2_/conv2_45/conv2d/kernel\n",
      "INFO:tensorflow:Initialize variable sphere_net_left/conv3_/conv3_89/conv2d_1/kernel from checkpoint model/sphere-100 with sphere_net/conv3_/conv3_89/conv2d_1/kernel\n",
      "INFO:tensorflow:Initialize variable sphere_net_left/conv3_/conv3_89/conv2d/kernel from checkpoint model/sphere-100 with sphere_net/conv3_/conv3_89/conv2d/kernel\n",
      "INFO:tensorflow:Initialize variable sphere_net_left/conv1_/conv1/conv2d/kernel from checkpoint model/sphere-100 with sphere_net/conv1_/conv1/conv2d/kernel\n",
      "INFO:tensorflow:Initialize variable sphere_net_left/conv3_/conv3_45/conv2d_1/kernel from checkpoint model/sphere-100 with sphere_net/conv3_/conv3_45/conv2d_1/kernel\n",
      "INFO:tensorflow:Initialize variable sphere_net_left/conv3_/conv3_67/name2/alpha from checkpoint model/sphere-100 with sphere_net/conv3_/conv3_67/name2/alpha\n",
      "INFO:tensorflow:Initialize variable sphere_net_left/conv4_/conv4/conv2d/kernel from checkpoint model/sphere-100 with sphere_net/conv4_/conv4/conv2d/kernel\n",
      "INFO:tensorflow:Initialize variable sphere_net_left/conv3_/conv3_23/conv2d_1/kernel from checkpoint model/sphere-100 with sphere_net/conv3_/conv3_23/conv2d_1/kernel\n",
      "INFO:tensorflow:Initialize variable sphere_net_left/conv2_/conv2_45/conv2d_1/kernel from checkpoint model/sphere-100 with sphere_net/conv2_/conv2_45/conv2d_1/kernel\n",
      "INFO:tensorflow:Initialize variable sphere_net_left/conv3_/conv3_67/name1/alpha from checkpoint model/sphere-100 with sphere_net/conv3_/conv3_67/name1/alpha\n",
      "INFO:tensorflow:Initialize variable sphere_net_left/conv4_/conv4/conv4/alpha from checkpoint model/sphere-100 with sphere_net/conv4_/conv4/conv4/alpha\n",
      "INFO:tensorflow:Initialize variable sphere_net_left/conv3_/conv3/conv2d/kernel from checkpoint model/sphere-100 with sphere_net/conv3_/conv3/conv2d/kernel\n",
      "INFO:tensorflow:Initialize variable sphere_net_left/conv4_/conv4/conv2d/bias from checkpoint model/sphere-100 with sphere_net/conv4_/conv4/conv2d/bias\n",
      "INFO:tensorflow:Initialize variable sphere_net_left/conv4_/conv4_23/name2/alpha from checkpoint model/sphere-100 with sphere_net/conv4_/conv4_23/name2/alpha\n",
      "INFO:tensorflow:Initialize variable sphere_net_left/conv3_/conv3_67/conv2d/kernel from checkpoint model/sphere-100 with sphere_net/conv3_/conv3_67/conv2d/kernel\n",
      "INFO:tensorflow:Initialize variable sphere_net_left/conv4_/conv4_23/conv2d/kernel from checkpoint model/sphere-100 with sphere_net/conv4_/conv4_23/conv2d/kernel\n",
      "INFO:tensorflow:Initialize variable sphere_net_left/conv2_/conv2_23/conv2d/kernel from checkpoint model/sphere-100 with sphere_net/conv2_/conv2_23/conv2d/kernel\n",
      "INFO:tensorflow:Initialize variable sphere_net_left/conv2_/conv2/conv2/alpha from checkpoint model/sphere-100 with sphere_net/conv2_/conv2/conv2/alpha\n",
      "INFO:tensorflow:Initialize variable sphere_net_left/conv2_/conv2_45/name2/alpha from checkpoint model/sphere-100 with sphere_net/conv2_/conv2_45/name2/alpha\n",
      "INFO:tensorflow:Initialize variable sphere_net_left/conv3_/conv3_23/name1/alpha from checkpoint model/sphere-100 with sphere_net/conv3_/conv3_23/name1/alpha\n",
      "INFO:tensorflow:Initialize variable sphere_net_left/conv1_/conv1/conv2d/bias from checkpoint model/sphere-100 with sphere_net/conv1_/conv1/conv2d/bias\n",
      "INFO:tensorflow:Initialize variable sphere_net_left/conv2_/conv2_23/name2/alpha from checkpoint model/sphere-100 with sphere_net/conv2_/conv2_23/name2/alpha\n",
      "INFO:tensorflow:Initialize variable sphere_net_left/conv4_/conv4_23/conv2d_1/kernel from checkpoint model/sphere-100 with sphere_net/conv4_/conv4_23/conv2d_1/kernel\n",
      "INFO:tensorflow:Initialize variable sphere_net_left/conv1_/conv1_23/name1/alpha from checkpoint model/sphere-100 with sphere_net/conv1_/conv1_23/name1/alpha\n",
      "INFO:tensorflow:Initialize variable sphere_net_left/conv1_/conv1_23/name2/alpha from checkpoint model/sphere-100 with sphere_net/conv1_/conv1_23/name2/alpha\n",
      "INFO:tensorflow:Initialize variable sphere_net_left/conv1_/conv1_23/conv2d/kernel from checkpoint model/sphere-100 with sphere_net/conv1_/conv1_23/conv2d/kernel\n",
      "INFO:tensorflow:Initialize variable sphere_net_left/feature/dense/bias from checkpoint model/sphere-100 with sphere_net/feature/dense/bias\n",
      "INFO:tensorflow:Initialize variable sphere_net_left/conv2_/conv2_23/name1/alpha from checkpoint model/sphere-100 with sphere_net/conv2_/conv2_23/name1/alpha\n",
      "INFO:tensorflow:Initialize variable sphere_net_left/conv2_/conv2/conv2d/bias from checkpoint model/sphere-100 with sphere_net/conv2_/conv2/conv2d/bias\n",
      "INFO:tensorflow:Initialize variable sphere_net_left/conv3_/conv3_23/conv2d/kernel from checkpoint model/sphere-100 with sphere_net/conv3_/conv3_23/conv2d/kernel\n",
      "INFO:tensorflow:Initialize variable sphere_net_left/conv3_/conv3/conv3/alpha from checkpoint model/sphere-100 with sphere_net/conv3_/conv3/conv3/alpha\n",
      "INFO:tensorflow:Initialize variable sphere_net_left/conv2_/conv2/conv2d/kernel from checkpoint model/sphere-100 with sphere_net/conv2_/conv2/conv2d/kernel\n",
      "INFO:tensorflow:Initialize variable sphere_net_left/conv3_/conv3_45/conv2d/kernel from checkpoint model/sphere-100 with sphere_net/conv3_/conv3_45/conv2d/kernel\n",
      "INFO:tensorflow:Initialize variable sphere_net_right/conv1_/conv1/conv2d/kernel from checkpoint model/sphere-100 with sphere_net/conv1_/conv1/conv2d/kernel\n",
      "INFO:tensorflow:Initialize variable sphere_net_right/conv4_/conv4/conv2d/kernel from checkpoint model/sphere-100 with sphere_net/conv4_/conv4/conv2d/kernel\n",
      "INFO:tensorflow:Initialize variable sphere_net_right/conv4_/conv4_23/conv2d/kernel from checkpoint model/sphere-100 with sphere_net/conv4_/conv4_23/conv2d/kernel\n",
      "INFO:tensorflow:Initialize variable sphere_net_right/conv3_/conv3_89/name2/alpha from checkpoint model/sphere-100 with sphere_net/conv3_/conv3_89/name2/alpha\n",
      "INFO:tensorflow:Initialize variable sphere_net_right/conv2_/conv2/conv2d/bias from checkpoint model/sphere-100 with sphere_net/conv2_/conv2/conv2d/bias\n",
      "INFO:tensorflow:Initialize variable sphere_net_right/conv3_/conv3_23/conv2d_1/kernel from checkpoint model/sphere-100 with sphere_net/conv3_/conv3_23/conv2d_1/kernel\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Initialize variable sphere_net_right/conv3_/conv3_45/name1/alpha from checkpoint model/sphere-100 with sphere_net/conv3_/conv3_45/name1/alpha\n",
      "INFO:tensorflow:Initialize variable sphere_net_right/conv3_/conv3_45/conv2d_1/kernel from checkpoint model/sphere-100 with sphere_net/conv3_/conv3_45/conv2d_1/kernel\n",
      "INFO:tensorflow:Initialize variable sphere_net_right/conv3_/conv3/conv2d/kernel from checkpoint model/sphere-100 with sphere_net/conv3_/conv3/conv2d/kernel\n",
      "INFO:tensorflow:Initialize variable sphere_net_right/conv3_/conv3_67/name1/alpha from checkpoint model/sphere-100 with sphere_net/conv3_/conv3_67/name1/alpha\n",
      "INFO:tensorflow:Initialize variable sphere_net_right/conv3_/conv3_89/conv2d/kernel from checkpoint model/sphere-100 with sphere_net/conv3_/conv3_89/conv2d/kernel\n",
      "INFO:tensorflow:Initialize variable sphere_net_right/conv4_/conv4_23/name2/alpha from checkpoint model/sphere-100 with sphere_net/conv4_/conv4_23/name2/alpha\n",
      "INFO:tensorflow:Initialize variable sphere_net_right/conv2_/conv2_23/conv2d_1/kernel from checkpoint model/sphere-100 with sphere_net/conv2_/conv2_23/conv2d_1/kernel\n",
      "INFO:tensorflow:Initialize variable sphere_net_right/conv3_/conv3_45/name2/alpha from checkpoint model/sphere-100 with sphere_net/conv3_/conv3_45/name2/alpha\n",
      "INFO:tensorflow:Initialize variable sphere_net_right/conv3_/conv3/conv2d/bias from checkpoint model/sphere-100 with sphere_net/conv3_/conv3/conv2d/bias\n",
      "INFO:tensorflow:Initialize variable sphere_net_right/feature/dense/bias from checkpoint model/sphere-100 with sphere_net/feature/dense/bias\n",
      "INFO:tensorflow:Initialize variable sphere_net_right/conv2_/conv2/conv2/alpha from checkpoint model/sphere-100 with sphere_net/conv2_/conv2/conv2/alpha\n",
      "INFO:tensorflow:Initialize variable sphere_net_right/conv3_/conv3_67/conv2d_1/kernel from checkpoint model/sphere-100 with sphere_net/conv3_/conv3_67/conv2d_1/kernel\n",
      "INFO:tensorflow:Initialize variable sphere_net_right/conv3_/conv3_23/conv2d/kernel from checkpoint model/sphere-100 with sphere_net/conv3_/conv3_23/conv2d/kernel\n",
      "INFO:tensorflow:Initialize variable sphere_net_right/conv3_/conv3_45/conv2d/kernel from checkpoint model/sphere-100 with sphere_net/conv3_/conv3_45/conv2d/kernel\n",
      "INFO:tensorflow:Initialize variable sphere_net_right/conv3_/conv3_89/conv2d_1/kernel from checkpoint model/sphere-100 with sphere_net/conv3_/conv3_89/conv2d_1/kernel\n",
      "INFO:tensorflow:Initialize variable sphere_net_right/conv4_/conv4/conv2d/bias from checkpoint model/sphere-100 with sphere_net/conv4_/conv4/conv2d/bias\n",
      "INFO:tensorflow:Initialize variable sphere_net_right/conv3_/conv3_67/conv2d/kernel from checkpoint model/sphere-100 with sphere_net/conv3_/conv3_67/conv2d/kernel\n",
      "INFO:tensorflow:Initialize variable sphere_net_right/conv4_/conv4_23/name1/alpha from checkpoint model/sphere-100 with sphere_net/conv4_/conv4_23/name1/alpha\n",
      "INFO:tensorflow:Initialize variable sphere_net_right/conv3_/conv3/conv3/alpha from checkpoint model/sphere-100 with sphere_net/conv3_/conv3/conv3/alpha\n",
      "INFO:tensorflow:Initialize variable sphere_net_right/conv2_/conv2_23/name1/alpha from checkpoint model/sphere-100 with sphere_net/conv2_/conv2_23/name1/alpha\n",
      "INFO:tensorflow:Initialize variable sphere_net_right/conv2_/conv2_23/name2/alpha from checkpoint model/sphere-100 with sphere_net/conv2_/conv2_23/name2/alpha\n",
      "INFO:tensorflow:Initialize variable sphere_net_right/conv3_/conv3_23/name1/alpha from checkpoint model/sphere-100 with sphere_net/conv3_/conv3_23/name1/alpha\n",
      "INFO:tensorflow:Initialize variable sphere_net_right/conv2_/conv2_45/name2/alpha from checkpoint model/sphere-100 with sphere_net/conv2_/conv2_45/name2/alpha\n",
      "INFO:tensorflow:Initialize variable sphere_net_right/conv1_/conv1_23/name1/alpha from checkpoint model/sphere-100 with sphere_net/conv1_/conv1_23/name1/alpha\n",
      "INFO:tensorflow:Initialize variable sphere_net_right/conv1_/conv1_23/conv2d_1/kernel from checkpoint model/sphere-100 with sphere_net/conv1_/conv1_23/conv2d_1/kernel\n",
      "INFO:tensorflow:Initialize variable sphere_net_right/conv4_/conv4_23/conv2d_1/kernel from checkpoint model/sphere-100 with sphere_net/conv4_/conv4_23/conv2d_1/kernel\n",
      "INFO:tensorflow:Initialize variable sphere_net_right/conv2_/conv2_45/conv2d/kernel from checkpoint model/sphere-100 with sphere_net/conv2_/conv2_45/conv2d/kernel\n",
      "INFO:tensorflow:Initialize variable sphere_net_right/conv1_/conv1/conv1/alpha from checkpoint model/sphere-100 with sphere_net/conv1_/conv1/conv1/alpha\n",
      "INFO:tensorflow:Initialize variable sphere_net_right/conv2_/conv2_45/name1/alpha from checkpoint model/sphere-100 with sphere_net/conv2_/conv2_45/name1/alpha\n",
      "INFO:tensorflow:Initialize variable sphere_net_right/feature/dense/kernel from checkpoint model/sphere-100 with sphere_net/feature/dense/kernel\n",
      "INFO:tensorflow:Initialize variable sphere_net_right/conv3_/conv3_67/name2/alpha from checkpoint model/sphere-100 with sphere_net/conv3_/conv3_67/name2/alpha\n",
      "INFO:tensorflow:Initialize variable sphere_net_right/conv2_/conv2/conv2d/kernel from checkpoint model/sphere-100 with sphere_net/conv2_/conv2/conv2d/kernel\n",
      "INFO:tensorflow:Initialize variable sphere_net_right/conv1_/conv1_23/name2/alpha from checkpoint model/sphere-100 with sphere_net/conv1_/conv1_23/name2/alpha\n",
      "INFO:tensorflow:Initialize variable sphere_net_right/conv3_/conv3_23/name2/alpha from checkpoint model/sphere-100 with sphere_net/conv3_/conv3_23/name2/alpha\n",
      "INFO:tensorflow:Initialize variable sphere_net_right/conv1_/conv1/conv2d/bias from checkpoint model/sphere-100 with sphere_net/conv1_/conv1/conv2d/bias\n",
      "INFO:tensorflow:Initialize variable sphere_net_right/conv1_/conv1_23/conv2d/kernel from checkpoint model/sphere-100 with sphere_net/conv1_/conv1_23/conv2d/kernel\n",
      "INFO:tensorflow:Initialize variable sphere_net_right/conv2_/conv2_23/conv2d/kernel from checkpoint model/sphere-100 with sphere_net/conv2_/conv2_23/conv2d/kernel\n",
      "INFO:tensorflow:Initialize variable sphere_net_right/conv2_/conv2_45/conv2d_1/kernel from checkpoint model/sphere-100 with sphere_net/conv2_/conv2_45/conv2d_1/kernel\n",
      "INFO:tensorflow:Initialize variable sphere_net_right/conv3_/conv3_89/name1/alpha from checkpoint model/sphere-100 with sphere_net/conv3_/conv3_89/name1/alpha\n",
      "INFO:tensorflow:Initialize variable sphere_net_right/conv4_/conv4/conv4/alpha from checkpoint model/sphere-100 with sphere_net/conv4_/conv4/conv4/alpha\n",
      "INFO:tensorflow:Initialize variable pred/kernel from checkpoint model/sphere-100 with pred/kernel\n",
      "INFO:tensorflow:Initialize variable pred/bias from checkpoint model/sphere-100 with pred/bias\n",
      "[[ 0.84942561]]\n"
     ]
    }
   ],
   "source": [
    "tf.reset_default_graph()\n",
    "from tensorflow.python.training import checkpoint_utils\n",
    "image = tf.constant(img_init,dtype=tf.float32)\n",
    "with tf.variable_scope('sphere_net_left'):\n",
    "    feature_left = infer(image)\n",
    "with tf.variable_scope('sphere_net_right'):\n",
    "    feature_right = infer(image)\n",
    "pred = tf.layers.dense(feature_left,1,name='pred')\n",
    "dense_var = [var for var in tf.trainable_variables() if 'pred' in var.name]\n",
    "print dense_var\n",
    "saver = tf.train.Saver(tf.trainable_variables())\n",
    "checkpoint_utils.init_from_checkpoint('model/sphere-100',{'sphere_net/':'sphere_net_left/'})\n",
    "checkpoint_utils.init_from_checkpoint('model/sphere-100',{'sphere_net/':'sphere_net_right/'})\n",
    "checkpoint_utils.init_from_checkpoint('model/sphere-100',{'pred/':'pred/'})\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    tf.global_variables_initializer().run(session=sess)\n",
    "    #saver.restore(sess,ckpt.model_checkpoint_path)\n",
    "    #saver.restore(sess,'model/sphere-100')\n",
    "    print(sess.run(pred))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Finetune your model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 115,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1, 7, 6, 512]\n",
      "[<tf.Variable 'pred/kernel:0' shape=(512, 1) dtype=float32_ref>, <tf.Variable 'pred/bias:0' shape=(1,) dtype=float32_ref>]\n",
      "INFO:tensorflow:Restoring parameters from model/sphere-100\n",
      "0.0113363 [[ 0.84942561]]\n",
      "[[ 0.8499887]]\n",
      "0.00535868 [[ 0.89647532]]\n",
      "[[ 0.89686245]]\n",
      "0.00253303 [[ 0.92882377]]\n",
      "[[ 0.92908984]]\n",
      "0.00119736 [[ 0.95106411]]\n",
      "[[ 0.95124698]]\n",
      "0.000565992 [[ 0.96635503]]\n",
      "[[ 0.96648079]]\n"
     ]
    }
   ],
   "source": [
    "tf.reset_default_graph()\n",
    "from tensorflow.python.training import checkpoint_utils\n",
    "image = tf.constant(img_init,dtype=tf.float32)\n",
    "with tf.variable_scope('sphere_net'):\n",
    "    feature_left = infer(image)\n",
    "#with tf.variable_scope('sphere_net_right'):\n",
    "#    feature_right = infer(image)\n",
    "pred = tf.layers.dense(feature_left,1,name='pred')\n",
    "loss = tf.nn.l2_loss(pred-1)\n",
    "#optimizer = tf.train.GradientDescentOptimizer(0.0001).minimize(loss)\n",
    "optimizer = tf.train.GradientDescentOptimizer(0.001)\n",
    "dense_var = [var for var in tf.trainable_variables() if 'pred' in var.name]\n",
    "grads = optimizer.compute_gradients(loss,dense_var)\n",
    "train_op = optimizer.apply_gradients(grads)\n",
    "\n",
    "print dense_var\n",
    "saver_vars = [var for var in tf.trainable_variables() if 'pred' not in var.name]\n",
    "#saver = tf.train.Saver(saver_vars)\n",
    "saver = tf.train.Saver(tf.trainable_variables())\n",
    "#checkpoint_utils.init_from_checkpoint('model/sphere-100',{'sphere_net/':'sphere_net_left/'})\n",
    "#checkpoint_utils.init_from_checkpoint('model/sphere-100',{'sphere_net/':'sphere_net_right/'})\n",
    "#checkpoint_utils.init_from_checkpoint('model/sphere-100',{'pred/':'pred/'})\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    tf.global_variables_initializer().run(session=sess)\n",
    "    #saver.restore(sess,ckpt.model_checkpoint_path)\n",
    "    saver.restore(sess,'model/sphere-100')\n",
    "    for i in xrange(500):\n",
    "        loss_np,pred_np,_ = sess.run([loss,pred,train_op])\n",
    "        if i % 100 ==0:\n",
    "            print loss_np,pred_np      \n",
    "            #saver.save(sess,'model/sphere',global_step=i,write_meta_graph=False)\n",
    "            pred_nd = sess.run(pred)\n",
    "            print(pred_nd)\n",
    "      "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
